import os
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from PIL import Image
import numpy as np
from torchvision import datasets,transforms,models

# Root directory of the chest X-ray dataset (Windows path, ends with a
# separator).
data_path = 'F:\\datas\\test_imgs\\archive\\chest_xray\\'

# Class labels; each one is also the name of a sub-directory under the
# train/ and test/ splits.
class_name = [
    'NORMAL',
    'PNEUMONIA',
]
def get_list_files(dirName):
    """Return the names of all entries directly inside ``dirName``.

    Thin wrapper around ``os.listdir``: the result is a list of bare entry
    names (not full paths), in whatever order the operating system reports
    them. Any error raised by ``os.listdir`` (for example when the directory
    does not exist) propagates to the caller.
    """
    return os.listdir(dirName)

# Directories holding the images of each class for the train and test splits.
# os.path.join is used instead of manual '\\' concatenation so separators are
# inserted consistently, and the same directory variable is reused below when
# an individual image path is built.
normal_train_dir = os.path.join(data_path, 'train', class_name[0])
pne_train_dir = os.path.join(data_path, 'train', class_name[1])
normal_test_dir = os.path.join(data_path, 'test', class_name[0])
pne_test_dir = os.path.join(data_path, 'test', class_name[1])

# Entry names (not full paths) found in each class directory.
file_normal_train = get_list_files(normal_train_dir)
file_pne_train = get_list_files(pne_train_dir)

file_normal_test = get_list_files(normal_test_dir)
file_pne_test = get_list_files(pne_test_dir)


# Report how many samples each class has in each split.
print('Normal类别的训练样本数：{}'.format(len(file_normal_train)))
print('PNEUMONIA类别的训练样本数：{}'.format(len(file_pne_train)))
print('Normal类别的测试样本数：{}'.format(len(file_normal_test)))
print('PNEUMONIA类别的测试样本数：{}'.format(len(file_pne_test)))

# Preview one randomly chosen NORMAL training image.
if not file_normal_train:
    # np.random.randint(0, 0) would otherwise fail with an opaque
    # "low >= high" ValueError; raise the same exception type with a
    # message that names the actual problem.
    raise ValueError(
        'No NORMAL training images found in {}'.format(normal_train_dir)
    )
# The upper bound of randint is exclusive, so the index is always valid.
rand_image_no = np.random.randint(0, len(file_normal_train))
img = mpimg.imread(
    os.path.join(normal_train_dir, file_normal_train[rand_image_no])
)
imgplot = plt.imshow(img)
# Uncomment the next line to open the preview window.
# plt.show()

